package com.ideaaedi.mybatis.data.security.util;

import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.segments.MergeSegments;
import com.baomidou.mybatisplus.core.conditions.segments.NormalSegmentList;
import com.baomidou.mybatisplus.core.enums.SqlKeyword;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.extension.conditions.AbstractChainWrapper;
import com.ideaaedi.mybatis.data.security.annotation.Encrypt;
import com.ideaaedi.mybatis.data.security.support.EncryptInfoHolder;
import com.ideaaedi.mybatis.data.security.support.EncryptParser;
import com.ideaaedi.mybatis.data.security.support.PojoCloneable;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.ibatis.mapping.MappedStatement;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;
import sun.reflect.generics.repository.ClassRepository;

import java.io.Serializable;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.lang.reflect.WildcardType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * mybatis-plus支持工具类
 *
 * @author <font size = "20" color = "#3CAA3C"><a href="https://gitee.com/JustryDeng">JustryDeng</a></font> <img
 * src="https://gitee.com/JustryDeng/shared-files/raw/master/JustryDeng/avatar.jpg" />
 * @since 1.0.0
 */
public final class MybatisPlusSupportUtil {
    
    /** 泛型左符号 */
    private static final String GENERIC_LEFT_SIGN = "<";
    
    /** 泛型右符号 */
    private static final String GENERIC_RIGHT_SIGN = ">";
    
    /** java.lang.Class#getGenericInfo() */
    private static final Method GET_GENERIC_INFO;
    
    /** java.lang.reflect.Method#getGenericSignature() */
    private static final Method GET_GENERIC_SIGNATURE;
    
    /** com.baomidou.mybatisplus.core.conditions.AbstractWrapper#entity */
    private static final Field ENTITY;
    
    static {
        GET_GENERIC_SIGNATURE = ReflectionUtils.findMethod(Method.class, "getGenericSignature");
        //noinspection ConstantConditions
        GET_GENERIC_SIGNATURE.setAccessible(true);
        
        GET_GENERIC_INFO = ReflectionUtils.findMethod(Class.class, "getGenericInfo");
        //noinspection ConstantConditions
        GET_GENERIC_INFO.setAccessible(true);
    
        ENTITY = ReflectionUtils.findField(AbstractWrapper.class, "entity");
        //noinspection ConstantConditions
        ENTITY.setAccessible(true);
    }
    
    /**
     * clazz是否是mybatis-plus对应的数据库表对应的数据模型
     *
     * @param clazz
     *            待判断的clazz
     * @return  是否是mybatis-plus中BaseMapper里的方法
     */
    public static boolean isDbEntity(Class<?> clazz) {
        if (clazz == null) {
            return false;
        }
        List<Annotation> list = new ArrayList<>(16);
        while (clazz != Object.class) {
            list.addAll(Arrays.asList(clazz.getAnnotations()));
            clazz = clazz.getSuperclass();
        }
        return list.stream().anyMatch(anno -> Objects.equals(anno.annotationType(), TableName.class));
    }
    
    /**
     * 加密warrapper<T>中相关的数据
     *
     * @param mappedStatement
     *            sql对应的mappedStatement实例
     * @param wrapper
     *            对象（可能包装有DbEntity相关数据）
     * @param encryptBeanInfoList
     *            所有需要加的类的信息
     * @param encryptOop
     *            加密器
     * @return  warrapper<T>对象
     */
    @SuppressWarnings("rawtypes")
    public static Wrapper encryptWrapper(MappedStatement mappedStatement, @SuppressWarnings("rawtypes") Wrapper wrapper,
                                         List<EncryptInfoHolder.BeanEncryptDetailInfo> encryptBeanInfoList,
                                         EncryptParser.EncryptOop encryptOop,
                                         @SuppressWarnings("rawtypes") Map<PojoCloneable, List<PojoCloneable>> originPojoAndClonePojoListMap) {
        Class<?> dbEntityClass = MybatisPlusSupportUtil.parseMybatisPlusDbEntityClass(mappedStatement);
        EncryptInfoHolder.BeanEncryptDetailInfo beanEncryptDetailInfo = encryptBeanInfoList.stream()
                .filter(x -> x.getBeanClass() == dbEntityClass).findFirst().orElse(null);
        if (beanEncryptDetailInfo != null) {
            Map<String, Pair<Field, Encrypt>> columnEncryptMap = beanEncryptDetailInfo.getColumnEncryptMap();
            // entity需要加密
            Object entity = wrapper.getEntity();
            AbstractWrapper abstractWrapper = null;
            if (entity instanceof PojoCloneable) {
                abstractWrapper = assertAbstractMapper(wrapper);
                entity = encryptOop.pojoEncrypt(mappedStatement, entity, encryptBeanInfoList, originPojoAndClonePojoListMap);
                try {
                    ENTITY.set(abstractWrapper, entity);
                } catch (IllegalAccessException e) {
                    throw new RuntimeException("invoke java.lang.reflect.Field.set error.", e);
                }
            } else {
                encryptOop.pojoEncrypt(mappedStatement, entity, encryptBeanInfoList, originPojoAndClonePojoListMap);
            }
            
            // wrapper里面的需要加密
            if (abstractWrapper == null) {
                abstractWrapper = assertAbstractMapper(wrapper);
            }
            if (abstractWrapper == null) {
                return wrapper;
            }
            // 加载一下sql，使得所有的参数(set参数 & where后的条件参数)都放能出现进paramNameValuePairs
            wrapper.getSqlSegment();
            //noinspection unchecked
            Map<String, Object> paramNameValuePairs = abstractWrapper.getParamNameValuePairs();
            if (CollectionUtils.isEmpty(paramNameValuePairs)) {
                return wrapper;
            }
        
            /*
             * 用于记录paramNameValuePairs中，需要加密的项的 key以及该加密项对应的数据库表列名
             */
            Map<String, String> paramNameValuePairsNeedEncryptKey = new LinkedHashMap<>();
            AtomicInteger idx = new AtomicInteger(0);
            // => 处理set部分（sqlSet值形如：name=#{ew.paramNameValuePairs.MPGENVAL1},gender=#{ew.paramNameValuePairs.MPGENVAL2}）
            String sqlSet = wrapper.getSqlSet();
            if (StringUtils.isNotBlank(sqlSet)) {
                Arrays.stream(sqlSet.split(","))
                        .filter(StringUtils::isNotBlank)
                        .forEach(str -> {
                            String columnName = str.split("=")[0].trim();
                            if (columnEncryptMap.containsKey(columnName)) {
                                paramNameValuePairsNeedEncryptKey.put(Constants.WRAPPER_PARAM + idx.incrementAndGet(), columnName);
                            } else {
                                idx.incrementAndGet();
                            }
                        });
            }
            // => 处理where部分
            MergeSegments expression = wrapper.getExpression();
            if (expression != null) {
                // 处理normal
                NormalSegmentList normal = expression.getNormal();
                int size = normal.size();
                for (int i = 0; i <= size - 1; i++) {
                    boolean currSegmentIsColumnName = i + 1 < size && normal.get(i + 1) instanceof SqlKeyword;
                    if (currSegmentIsColumnName) {
                        String columnName = normal.get(i).getSqlSegment();
                        columnName = columnName.replace("`", "");
                        boolean existNeedEncryptColumn = columnEncryptMap.containsKey(columnName);
                        boolean existValueSegment = i + 2 < size;
                        if (existNeedEncryptColumn && existValueSegment) {
                            String sqlSegment = normal.get(i + 2).getSqlSegment();
                            // 类似于in， 一个sql片段中可能会有多个占位符。 如： (#{ew.paramNameValuePairs.MPGENVAL1},#{ew.paramNameValuePairs.MPGENVAL2})
                            int count = StringUtils.countMatches(sqlSegment, Constants.WRAPPER_PARAM);
                            for (int j = 0; j < count; j++) {
                                paramNameValuePairsNeedEncryptKey.put(Constants.WRAPPER_PARAM + idx.incrementAndGet(), columnName);
                            }
                        } else {
                            idx.incrementAndGet();
                        }
                    }
                }
            }
        
            Map<String, Object> tmpMap = new HashMap<>(16);
            for (Map.Entry<String, Object> entry : paramNameValuePairs.entrySet()) {
                String key = entry.getKey();
                Object value = entry.getValue();
                if (value instanceof String && paramNameValuePairsNeedEncryptKey.containsKey(key)) {
                    Pair<Field, Encrypt> fieldEncrypt = columnEncryptMap.get(paramNameValuePairsNeedEncryptKey.get(key));
                    Objects.requireNonNull(fieldEncrypt, "fieldEncrypt cannot be null.");
                    Field field = fieldEncrypt.getKey();
                    // 加密并替换
                    value = encryptOop.stringEncrypt(field.getName(), value.toString(), fieldEncrypt.getValue(), wrapper);
                }
                tmpMap.put(key, value);
            }
            // 替换为密文的
            paramNameValuePairs.putAll(tmpMap);
        }
        return wrapper;
    }
    
    /**
     * 筛选出mappedStatement对应的方法
     *
     * @param mappedStatement
     *            sql的MappedStatement实例
     * @param targetMethodList
     *            根据mappedStatement。getId()里面的类名和方法名匹配出来的Method,但因为可能重载，所以可能有多个
     * @return  过滤掉不满足校验的方法后，剩下的满足校验方法
     */
    public static List<Method> filterMethod(MappedStatement mappedStatement, List<Method> targetMethodList) {
        /*
         * 专门处理mybatis-plus的重载方法
         * 如：在mybatis-plus 3.5.1版本中就存在重载方法
         * com.baomidou.mybatisplus.core.mapper.BaseMapper.deleteById(java.io.Serializable)
         * com.baomidou.mybatisplus.core.mapper.BaseMapper.deleteById(T)
         */
        List<Method> methods = new ArrayList<>(targetMethodList);
        int size = targetMethodList.size();
        if (size > 1) {
            Class<?> parameterType = mappedStatement.getParameterMap().getType();
            boolean isDbEntity = isDbEntity(parameterType);
            Class<?> targetParamClass = isDbEntity ? Object.class : Serializable.class;
            methods = methods.stream().filter(m -> {
                for (Class<?> paramClass : m.getParameterTypes()) {
                    if (targetParamClass == paramClass) {
                        return true;
                    }
                }
                return false;
            }).collect(Collectors.toList());
        }
        return methods;
    }
    
    /**
     * method是否是mybatis-plus中BaseMapper里的方法
     *
     * @param method
     *            方法
     * @return  是否是mybatis-plus中BaseMapper里的方法
     */
    public static boolean isMybatisPlusMethod(Method method) {
        return Objects.equals(method.getDeclaringClass(), BaseMapper.class);
    }
    
    /**
     * 方法的参数是否包含要加密的T
     * <br />
     * 主要用于识别参数中是否有： Wrapper<T> 或者 T
     *
     * @param method
     *            目标方法
     * @return  参数否是包含T
     */
    public static boolean parameterContainDbEntity(Method method) {
        for (Parameter parameter : method.getParameters()) {
            Class<?> parameterType = parameter.getType();
            // Wrapper<T>对应com.baomidou.mybatisplus.core.conditions.Wrapper
            if (Wrapper.class.equals(parameterType)) {
                return true;
            }
            // T对应java.lang.Object
            if (Object.class.equals(parameterType)) {
                return true;
            }
        }
        return false;
    }
    
    /**
     * 判断object是否是Wrapper类
     * <br />
     *
     * @param object
     *            判断的对象
     * @return  object是否是Wrapper类
     */
    public static boolean isWrapperSubclass(Object object) {
        return object instanceof Wrapper;
    }

    /**
     * 方法的返回类型中是否包含T
     * <br/>
     * 样例：
     *
     * @param method
     *            返回值类型
     * @return  返回值中是否包含T
     */
    public static boolean returnClassContainDbEntity(Method method) {
        Object getGenericSignature = ReflectionUtils.invokeMethod(GET_GENERIC_SIGNATURE, method);
        if (getGenericSignature == null) {
            return false;
        }
        // 返回值没有明确的指定泛型，但由于参数存在泛型，导致getGenericSignature不为null的情况，如上面的示例2
        String signature = getGenericSignature.toString();
        /*
         * <P extends IPage<T>> P selectPage(P page, @Param(Constants.WRAPPER) Wrapper<T> queryWrapper); 的返回值也需要解密，其签名是：
         * <P::Lcom/baomidou/mybatisplus/core/metadata/IPage<TT;>;>(TP;Lcom/baomidou/mybatisplus/core/conditions/Wrapper<TT;>;)TP;
         */
        if (signature.contains("IPage<TT;>")) {
            return true;
        }
        String returnObjSignature = signature.substring(signature.lastIndexOf(")") + 1);
        if (returnObjSignature.contains("<")) {
            int startIndex = returnObjSignature.indexOf(GENERIC_LEFT_SIGN);
            int endIndex = returnObjSignature.lastIndexOf(GENERIC_RIGHT_SIGN);
            if (startIndex == -1 || endIndex == -1) {
                return false;
            }
            String genericInfo = returnObjSignature.substring(startIndex + 1, endIndex);
            return genericInfo.contains("TT");
        } else {
            return returnObjSignature.contains("TT");
        }
    }
    
    /**
     * 解析mappedStatement对应的数据模型的class
     *
     * <ol>
     *     <li>假设mappedStatement对应的sql是com.ideaaedi.demo.mapper.PaymentMapper.selectList</li>
     *     <li>第一步：解析出PaymentMapper的父类为com.baomidou.mybatisplus.core.mapper.BaseMapper<com.ideaaedi.demo.entity.PaymentPO></li>
     *     <li>第二步：解析出数据模型com.ideaaedi.demo.entity.PaymentPO</li>
     *     <li>第三步：返回数据模型com.ideaaedi.demo.entity.PaymentPO的class对象</li>
     * </ol>
     *
     * @param mappedStatement
     *            sql的MappedStatement实例
     * @return  mappedStatement对应的数据模型的class
     */
    public static Class<?> parseMybatisPlusDbEntityClass(MappedStatement mappedStatement) {
        try {
            String mappedStatementId = mappedStatement.getId();
            int lastIdx = mappedStatementId.lastIndexOf(".");
            // public interface MyMapper extends BaseMapper<Product>, 这里获取到MyMapper类
            Class<?> realClass = Class.forName(mappedStatementId.substring(0, lastIdx));
            ClassRepository classRepository = (ClassRepository) ReflectionUtils.invokeMethod(GET_GENERIC_INFO, realClass);
            assert classRepository != null;
            // 获取到形如'com.baomidou.mybatisplus.core.mapper.BaseMapper<cn.onehome.flex.core.db.entity.OrderDetailPayment>'这样的字符串
            String baseMapperWithGeneric = Arrays.stream(classRepository.getSuperInterfaces()).map(Type::getTypeName)
                    .filter(x -> x.startsWith(BaseMapper.class.getName()))
                    .findFirst().orElseThrow(
                            () -> new IllegalStateException("Cannot find superInterface '" + BaseMapper.class.getName() + "' subclass is "
                                    + "'" + realClass + "'")
                    );
            int startIdx = baseMapperWithGeneric.indexOf("<");
            int endIdx = baseMapperWithGeneric.lastIndexOf(">");
            return Class.forName(baseMapperWithGeneric.substring(startIdx + 1, endIdx));
        } catch (Exception e) {
            throw new IllegalStateException(String.format("Parse mybatis-plus exception. Determine db-entity class fail. curr mappedStatement is '%s'.", mappedStatement.getId()), e);
        }
    }
    
    /**
     * wrapper转换为AbstractWrapper
     */
    @SuppressWarnings("rawtypes")
    private static AbstractWrapper assertAbstractMapper(Wrapper wrapper) {
        AbstractWrapper abstractWrapper;
        if (wrapper instanceof AbstractWrapper) {
            //noinspection rawtypes
            abstractWrapper = (AbstractWrapper) wrapper;
        } else if (wrapper instanceof AbstractChainWrapper) {
            //noinspection rawtypes
            abstractWrapper = ((AbstractChainWrapper) wrapper).getWrapper();
        } else {
            throw new IllegalStateException("un-support wrapper type for " + wrapper.getClass().getCanonicalName());
        }
        return abstractWrapper;
    }
}
